server: implement GLM-style MTP #15225
base: master
Conversation
This is correct - we always alternate between conventional and speculative passes. It's definitely not optimal, but it improves flexibility for regular sampling: it allows changing the speculative parameters, and even disabling speculation, per request, while keeping the logic quite simple. It should be possible to improve this by keeping track of which slots are speculating on each iteration and skipping the addition of their tokens to the conventional batch. It might be a good idea to implement that separately, to avoid huge changes to the logic in a single PR.
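A minimal sketch of that per-slot bookkeeping idea, with illustrative names only (the real `server_slot` struct and batch-building code in server.cpp look different):

```cpp
#include <cstdint>
#include <vector>

// Illustrative only: a per-slot flag marking whether a speculative pass is in
// flight, so the conventional batch can skip those slots on this iteration.
struct slot_state {
    int32_t id          = -1;
    int32_t sampled     = 0;     // last sampled token id
    int32_t n_past      = 0;     // tokens already in the KV cache
    bool    speculating = false; // set while this slot is mid-speculation
};

// Returns the ids of the slots whose last sampled token should be added to the
// conventional batch; slots that are currently speculating are skipped.
static std::vector<int32_t> slots_for_conventional_batch(const std::vector<slot_state> & slots) {
    std::vector<int32_t> ids;
    for (const auto & s : slots) {
        if (!s.speculating) {
            ids.push_back(s.id);
        }
    }
    return ids;
}
```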
Generally, we should try to minimize the changes to the existing code. On first look, I think the path that involves minimal changes is:
- Extracting the MTP logits during

Currently, I am not sure which way is better. The first requires a new API call, while the second might break some existing assumptions (not sure if that's the case yet). In any case, you can avoid this until you get the implementation working with a reasonable speedup. After that, we can discuss further how to best refactor the implementation.
I don't see an issue with adding a new API for this, and it would be easier to use.
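To make the new-API option concrete, here is one purely hypothetical shape for it, loosely mirroring the existing `llama_get_logits_ith`; the name, semantics, and even whether it belongs in `llama.h` are all open questions:

```cpp
// Hypothetical llama.h addition - NOT an existing API.
// Returns the logits produced by the MTP (NextN) head for the i-th token of
// the last llama_decode call, or nullptr if the model has no MTP head or the
// token did not have logits enabled (mirroring llama_get_logits_ith).
LLAMA_API float * llama_get_logits_mtp_ith(struct llama_context * ctx, int32_t i);
```

The server could then call this right after the conventional decode to build the next draft, without evaluating a second graph.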
Out of curiosity, is the API for this expected to be flexible enough that we could jump off of it to add things like Medusa/Eagle-style (or IBM Accelerator) self-speculative decoding heads? I'm pretty sure they work fairly similarly (depending on the final output embeddings of the current token).

Another note: after some consideration, I think the expected speedup of the MTP module will depend a lot on the hardware the model is running on, particularly because it's an MoE model. While the next-token prediction depends only on the current state, self-speculative decoding adds extra forward passes for verification. Those passes aren't guaranteed to have the same expert-usage patterns, meaning the speedup should be some function of the number of tokens predicted and the expert re-use coefficient for the tokens verified. So, just noting that if this is implemented and there isn't a 2x or 3x increase in t/s, it may not be a skill issue on the part of a contributor, but a consequence of the mathematical nature of the calculation.

For people running franken setups with attention/KV cache on GPU and MoE FFNs on CPU, it's possible that touching previously unused experts in the verification sweep results in a weird situation where the parallel verification process is actually memory-bandwidth bound.

Not to discourage the implementation of this - I just wanted to give a heads up so nobody's dejected if the theoretical speedups can't be hit. There should still be at least some speedup, though.
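As a rough, illustrative back-of-envelope for that last point (the symbols below are my own, not from the thread): in a memory-bandwidth-bound regime, and ignoring the small cost of the MTP head itself, the per-cycle speedup is roughly

$$
S \;\approx\; \frac{1 + \mathbb{E}[k_{\text{acc}}]}{W_{\text{experts}}(n) \,/\, W_{\text{experts}}(1)}
$$

where $\mathbb{E}[k_{\text{acc}}]$ is the expected number of accepted draft tokens per verification pass, $n$ is the number of tokens verified in parallel, and $W_{\text{experts}}(n)$ is the amount of unique expert weights that must be streamed to process $n$ tokens. With high expert re-use, $W_{\text{experts}}(n) \approx W_{\text{experts}}(1)$ and the speedup approaches $1 + \mathbb{E}[k_{\text{acc}}]$; with mostly disjoint experts, the denominator grows and the gain shrinks, which is the franken-setup scenario described above.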
This is very much a draft/proof of concept I'm playing with, just one idea for an MTP implementation. Planning to test on GLM-4.5 because it's the only model out there that we've preserved NextN tensors for.
From what I can tell

So implementation-wise, it seems like we want a `mtp_speculative_gen_draft` in speculative.cpp that is vastly simplified, and to branch into it in server.cpp when a slot has MTP (versus `common_speculative_gen_draft`). We also need to figure out what to do about `ctx_dft` in this case. It's a bit hacky, but I was thinking we could just have `ctx_dft = ctx` and then have both the normal and MTP passes write over the shared `ctx` logits. I think this minimizes the required code changes elsewhere.

This is my first time (1) working with ML stuff outside of Python and (2) attempting to contribute, so patience is appreciated :)
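To make that concrete, here is a very rough sketch of the shape described above, with the signature loosely modeled on `common_speculative_gen_draft`; everything below is illustrative and not actual code from this PR:

```cpp
#include "llama.h"

#include <vector>

using llama_tokens = std::vector<llama_token>;

// Hypothetical helper for speculative.cpp: draft token(s) from the model's own
// MTP/NextN head instead of a separate draft model. No separate ctx_dft is
// needed - the shared ctx serves both the conventional and the MTP pass.
static llama_tokens mtp_speculative_gen_draft(
        struct llama_context * ctx,       // shared context (ctx_dft == ctx)
        llama_token            id_last,   // last sampled token to extend from
        int32_t                i_batch) { // batch index of id_last in the last decode
    llama_tokens draft;
    // Intended flow (the details are exactly what this PR explores):
    //   1. feed id_last plus the hidden state from the last conventional pass
    //      through the NextN/MTP layer;
    //   2. greedily take the top MTP logit as a single draft token;
    //   3. let the next conventional pass verify it, overwriting the ctx logits.
    (void) ctx; (void) id_last; (void) i_batch;
    return draft;
}

// Hypothetical branch point in server.cpp:
//   llama_tokens draft = slot.has_mtp
//       ? mtp_speculative_gen_draft(ctx, id_last, slot.i_batch)
//       : common_speculative_gen_draft(slot.spec, params_spec, cached_tokens, id_last);
```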